Our task is to predict flu infections using the methods described here.
We use a dataset of weekly flu case counts from 1997 through 2021.
import datetime
from dateutil.relativedelta import relativedelta
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
from scipy.optimize import least_squares
from scipy.signal import find_peaks
TODO: change the data import to load the file from IBM Cloud.
# Load the raw weekly flu case counts from the local Excel workbook
# (expected columns: YEAR, WEEK, Total_cases).
raw_path = 'flu_cases_1997_2021_raw_data.xlsx'
flu_data = pd.read_excel(raw_path)
flu_data.head(5)
| | YEAR | WEEK | Total_cases |
|---|---|---|---|
| 0 | 1997 | 40 | 0 |
| 1 | 1997 | 41 | 11 |
| 2 | 1997 | 42 | 17 |
| 3 | 1997 | 43 | 8 |
| 4 | 1997 | 44 | 10 |
# Build a calendar date from the YEAR and WEEK columns: January 1st of YEAR
# shifted forward by 7 * WEEK days (e.g. WEEK 40 of 1997 -> 1997-10-08).
# NOTE(review): with this formula WEEK 1 maps to January 8th rather than
# January 1st -- confirm this matches the intended week-numbering convention.
year_start = pd.to_datetime(flu_data['YEAR'].astype(str), format='%Y')
week_offset = pd.to_timedelta(flu_data['WEEK'] * 7, unit='D')
flu_data['date'] = year_start + week_offset
flu_data.head()
| | YEAR | WEEK | Total_cases | date |
|---|---|---|---|---|
| 0 | 1997 | 40 | 0 | 1997-10-08 |
| 1 | 1997 | 41 | 11 | 1997-10-15 |
| 2 | 1997 | 42 | 17 | 1997-10-22 |
| 3 | 1997 | 43 | 8 | 1997-10-29 |
| 4 | 1997 | 44 | 10 | 1997-11-05 |
# Add a running day counter: row k gets 7 * k, i.e. rows are treated as
# consecutive weeks starting at day 0.
days = [7 * week_idx for week_idx in range(len(flu_data))]
flu_data['days'] = days
flu_data.head()
| | YEAR | WEEK | Total_cases | date | days |
|---|---|---|---|---|---|
| 0 | 1997 | 40 | 0 | 1997-10-08 | 0 |
| 1 | 1997 | 41 | 11 | 1997-10-15 | 7 |
| 2 | 1997 | 42 | 17 | 1997-10-22 | 14 |
| 3 | 1997 | 43 | 8 | 1997-10-29 | 21 |
| 4 | 1997 | 44 | 10 | 1997-11-05 | 28 |
# Interactive line chart of the weekly case counts over the whole period.
fig = px.line(
    data_frame=flu_data,
    x='date',
    y="Total_cases",
    hover_data=['date'],
)
fig.show()
We can see that the data since April 2020 differ from previous years — probably because of the COVID-19 pandemic. We therefore decided to remove that data from our dataset.
# Drop the observations from April 2020 onward, which behave differently from
# earlier years (apparently because of the COVID-19 pandemic).
# Cutting on the date instead of removing whole calendar years keeps
# January-March 2020, which belong to the 2019/20 season and may contain its
# peak; dropping all of YEAR 2020 would discard that season's peak as well.
# The cutoff is applied to the 'date' column built from YEAR/WEEK above.
COVID_CUTOFF = pd.Timestamp('2020-04-01')
flu_data = flu_data[flu_data['date'] < COVID_CUTOFF]
fig = px.line(flu_data, x='date', y="Total_cases", hover_data=['date'])
fig.show()
# peaks - POSITIONAL indexes (0..len-1) of local maxima of the case counts that
# are at least 24 rows apart (24 weeks for weekly rows).
peaks, _ = find_peaks(flu_data['Total_cases'], distance=24)
x = flu_data['Total_cases']
# Select the peak rows by position. Plain `x[peaks]` would look the integers up
# as index LABELS, which only coincides with positions while the index happens
# to be a contiguous 0..n-1 range.
peak_rows = flu_data.iloc[peaks]
fig = px.line(flu_data, x='date', y="Total_cases", hover_data=['date'])
fig.add_scatter(
    x=peak_rows['date'],
    y=peak_rows['Total_cases'],
    mode="markers",
    marker_symbol='x',
    marker=dict(size=8),
    name='local maxima',
)
fig.show()
# Keep only the rows at the detected peak positions (positional selection).
peaks_flu = flu_data.take(peaks)
def PolyCoefficients(x, coeffs):
    """Evaluate the polynomial with coefficients ``coeffs`` at ``x``.

    ``coeffs`` are ordered from the highest power down to the constant term,
    i.e. the order returned by ``np.polyfit`` and accepted by ``np.polyval``:
    ``coeffs[0] * x**(k-1) + coeffs[1] * x**(k-2) + ... + coeffs[-1]`` with
    ``k = len(coeffs)``.

    Args:
        x: point(s) at which to evaluate -- a scalar, NumPy array, pandas
            Series, or a list/tuple (converted to a NumPy array).
        coeffs: sequence of polynomial coefficients, highest degree first.

    Returns:
        The polynomial value(s) with the same shape/type as ``x`` (a Series
        stays a Series); ``0`` when ``coeffs`` is empty.
    """
    if isinstance(x, (list, tuple)):
        # `list * number` means repetition in Python, not multiplication.
        x = np.asarray(x)
    y = 0
    # Horner's scheme: (((c0 * x + c1) * x + c2) * x + ...).
    for c in coeffs:
        y = y * x + c
    return y
# Day offsets of every observation; used as the x-axis for the fitted curves.
x = flu_data['days']
# Straight line through the seasonal peaks (degree-1 least squares), then
# evaluated one year (365 days) after the last detected peak.
p_1 = np.polyfit(peaks_flu['days'], peaks_flu['Total_cases'], 1)
np.polyval(p_1, peaks_flu['days'].iloc[-1] + 365)
3767.5588269873924
# Least-squares polynomial fits of degrees 2..6 through the seasonal peaks;
# each p_k holds coefficients ordered from the highest power down.
p_2, p_3, p_4, p_5, p_6 = (
    np.polyfit(peaks_flu['days'], peaks_flu['Total_cases'], degree)
    for degree in range(2, 7)
)
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# One subplot per fitted polynomial (degrees 1..6), laid out row by row on a
# 3x2 grid and evaluated at every day offset in `x`.
poly_fits = (p_1, p_2, p_3, p_4, p_5, p_6)
fig = make_subplots(rows=3, cols=2)
for slot, coeffs in enumerate(poly_fits):
    grid_row, grid_col = divmod(slot, 2)
    fig.add_trace(
        go.Scatter(x=x, y=PolyCoefficients(x, coeffs), name=f'deg={slot + 1}'),
        row=grid_row + 1,
        col=grid_col + 1,
    )
fig.update_layout(title_text="Polynomial plots for different degree of polynomial")
fig.show()
def KMcK(S1, I1, R1, n, b=0.434, a=0.4):
    """Simulate a discrete-time SIR epidemic (Kermack-McKendrick-type model).

    At each step ``b * I * S / N`` moves from S to I and ``a * I`` moves from
    I to R, where ``N = S1 + I1 + R1`` is the total population (constant,
    because the update conserves S + I + R).

    Dividing the incidence by N makes ``b`` a per-capita transmission rate.
    As a result, the compartments may be given either as counts (e.g.
    S1=1000) or as fractions summing to 1. For fractions, the division by
    N = 1 leaves the results unchanged. Without this normalisation, count
    inputs combined with a rate of this size make S negative, and the series
    diverge to inf/nan within a few steps.

    Args:
        S1: initial number (or fraction) of susceptible individuals.
        I1: initial number (or fraction) of infected individuals.
        R1: initial number (or fraction) of recovered individuals.
        n: number of time steps to simulate; n <= 0 returns just the initial
            state.
        b: transmission rate per step.
        a: recovery rate per step.

    Returns:
        Tuple ``(S, I, R)`` of lists starting with the initial values, each of
        length ``n + 1`` for ``n >= 0``.
    """
    total = S1 + I1 + R1  # conserved by the update below
    S = [S1]
    I = [I1]
    R = [R1]
    for i in range(n):
        # An empty population has nobody to infect; avoid 0/0.
        new_infections = b * I[i] * S[i] / total if total else 0.0
        new_recoveries = a * I[i]
        S.append(S[i] - new_infections)
        I.append(I[i] + new_infections - new_recoveries)
        R.append(R[i] + new_recoveries)
    return S, I, R
# Ten steps from 1000 susceptible, 11 infected, 0 recovered, default rates.
KMcK(1000, 11, 0, 10)
([1000, -3774.0, 7826447.2296, 26587005881167.25, 3.0678105855204267e+26, 4.0845744162657916e+52, 7.240746702313656e+104, 2.275393115826668e+209, inf, inf, nan], [11, 4780.6, -7827352.869600001, -26587002751131.742, -3.06781058552032e+26, -4.0845744162657916e+52, -7.240746702313656e+104, -2.275393115826668e+209, -inf, nan, nan], [0, 4.4, 1916.6400000000003, -3129024.5078400006, -10634804229477.205, -1.2271242342082344e+26, -1.6338297665063168e+52, -2.8962986809254624e+104, -9.101572463306672e+208, -inf, nan])